#include "NetworkProxy.h"

NAMESPACE_UPP

// File-local trace switches, toggled by NetworkProxy::Trace().
static bool sNetworkProxyTrace = false;
static bool sNetworkProxyTraceVerbose = false;

// Logging helpers: LLOG emits only when tracing is enabled, LDUMPHEX only when
// verbose tracing is enabled. The do { } while(false) wrapper makes each macro a
// single statement, so it is safe inside unbraced if/else.
#define LLOG(x)       do { if(sNetworkProxyTrace) RLOG(x); } while(false)
#define LDUMPHEX(x)   do { if(sNetworkProxyTraceVerbose) RDUMPHEX(x); } while(false)
	
// Turns on diagnostic logging for all proxies in this file.
// verbose: when true, raw packets are hex-dumped as well.
void NetworkProxy::Trace(bool verbose)
{
	sNetworkProxyTraceVerbose = verbose;
	sNetworkProxyTrace = true;
}

// Returns the (translated) description for an error code.
// The table covers NetworkProxy's own codes (10000-10010), the SOCKS4-related codes
// (91-94) and the SOCKS5-related codes (1-8, 255, 256).
// Returns the sentinel string "-1" for codes not in the table; SetError() relies on it.
// NOTE: the messages are t_() translation keys. Correcting a typo here means the
// matching entries in the translation files must be updated too.
const char* NetworkProxy::GetErrorMessage(int code)
{
	static const Tuple2<int, const char*> errors[] = {
		// NetProxy error messages.
		{ 10000,	t_("No client to serve (No socket attached).") },
		{ 10001,	t_("Proxy address or port not specified.") },
		{ 10002,	t_("Target address or port not specified.") },
		{ 10003,	t_("Couldn't resolve address.") },
		{ 10004,	t_("Couldn't connect to proxy server.") },
		{ 10005,	t_("Couldn't start SSL negotiation.") },
		{ 10006,	t_("Invalid packet received.") },
		{ 10007,	t_("Socket error occurred.") },
		{ 10008,	t_("Operation was aborted.") },
		{ 10009,	t_("Connection timed out.") },
		// Http CONNECT method error message.
		{ 10010,	t_("Http CONNECT method failed.") },
		// SOCKS4 protocol error messages.
		{ 91,		t_("Request rejected or failed.") },
		{ 92,		t_("Request failed. Client is not running identd (or not reachable from the server).") },
		{ 93,		t_("Request failed. Client's identd could not confirm the user ID string in the request.") },
		{ 94,		t_("Socks4 protocol doesn't support IP version 6 address family. Consider using Socks5 protocol instead.") },
		// SOCKS5 protocol error messages.
		{ 1,		t_("General failure.") },
		{ 2,		t_("Connection not allowed by the ruleset.")},
		{ 3,		t_("Network unreachable.") },
		{ 4,		t_("Target machine unreachable.") },
		{ 5,		t_("Connection refused by the destination host.")},
		{ 6,		t_("TTL expired.") },
		{ 7,		t_("Command not supported / protocol error.") },
		{ 8,		t_("Address type not supported.") },
		{ 255,		t_("Invalid authentication method. No acceptable methods were offered.") },
		{ 256,		t_("Authentication failed.") },
	};
	const Tuple2<int, const char *> *x = FindTuple(errors, __countof(errors), code);
	// "-1" marks an unknown code; SetError() turns it into a generic message.
	return x ? x->b : "-1";
}

bool NetworkProxy::SetError(int code)
{
	SetPhase(FAILED);
	error_desc = NetworkProxy::GetErrorMessage(code);
	if(error_desc.StartsWith("-1")) {
		error_desc = t_("Unknown error code.");
		error = -1;
	}
	else
		error = code;
	LLOG(Format("-- NetworkProxy error (code: %d): %s", error, error_desc));
	return false;
}

// Starts resolving host:port into ipaddrinfo and enters the DNS phase.
// - Socket with a timeout (non-blocking use): the lookup is only started here;
//   Dns() polls for completion on later Do() calls.
// - Socket without a timeout: the lookup completes here, either while pumping the
//   socket's WhenWait callback or via a blocking Execute(). On success, NextStep()
//   is invoked directly.
// A lookup that finishes without a result fails with DNS_FAILED.
void NetworkProxy::StartDns(const String& host, int port)
{
	LLOG(Format("** Starting DNS sequence (locally) for: %s:%d", host, port));
	SetPhase(DNS);
	if(!IsNull(socket->GetTimeout())) {
		ipaddrinfo.Start(host, port);
		return;
	}
	if(socket->WhenWait) {
		ipaddrinfo.Start(host, port);
		while(ipaddrinfo.InProgress()) {
			Sleep(socket->GetWaitStep());
			socket->WhenWait();
			// Leave the phase as DNS; CheckConnection() reports the timeout.
			if(msecs(start_time) > timeout)
				return;
		}
		// Do not advance on a failed lookup: the next step would use an empty
		// ipaddrinfo (NULL GetResult()).
		if(!ipaddrinfo.GetResult()) {
			SetError(DNS_FAILED);
			return;
		}
	}
	else
	if(!ipaddrinfo.Execute(host, port)) {
		SetError(DNS_FAILED);
		return;
	}
	NextStep();
}

void NetworkProxy::Dns()
{
	for(int i = 0; i <= Nvl(socket->GetTimeout(), INT_MAX); i++) {
		if(!ipaddrinfo.InProgress()) {
			NextStep();
			return;
		}
		Sleep(1);
	}
}

// Writes as much of the pending request in `packet` as the socket accepts, starting
// at offset packet_length (bytes already sent).
// Returns true while some bytes remain unsent, false when the whole packet is out.
// Sizes are byte counts (GetCount()): the packets are binary, and GetCharCount()
// would count UTF-8 characters instead, truncating the send.
bool NetworkProxy::Put()
{
	const int total = packet.GetCount();
	while(packet_length < total) {
		int n = socket->Put(~packet + packet_length, total - packet_length);
		if(n <= 0)
			break;
		packet_length += n;
	}
	return packet_length < total;
}

// Reads the reply one byte at a time, appending each byte to `packet`.
// Returns false once IsEof() reports a complete reply. Returns true when the socket
// has no more data for now (reply still incomplete).
bool NetworkProxy::Get()
{
	char ch;
	while(socket->Get(&ch, sizeof(char)) != 0) {
		packet.Cat(ch);
		if(IsEof())
			return false;
	}
	return true;
}

void NetworkProxy::Init()
{
	if(socket == NULL) {
		SetError(NO_SOCKET_ATTACHED); 
		return;
	}
	if(proxy_host.IsEmpty() || !proxy_port) {
		SetError(HOST_NOT_SPECIFIED);
		return;
	}
	if(target_host.IsEmpty() || !target_port) {
		SetError(TARGET_NOT_SPECIFIED); 
		return;
	}
	socket->Clear();
	socket->GlobalTimeout(timeout);
	packet.Clear();
	packet_length = 0;
	error = 0;
	error_desc.Clear();
	start_time = msecs();
	NextStep = THISBACK(DoConnect);	
}

// Run after every Do() step. It first turns socket errors, abort requests and an
// elapsed overall timeout into proxy errors (in that priority order). Then, if the
// proxy has failed, it releases the lookup result and buffer and closes the socket.
void NetworkProxy::CheckConnection()
{
	if(phase != FAILED) {
		if(socket->IsError())
			SetError(SOCKET_FAILURE);
		else if(socket->IsAbort())
			SetError(ABORTED);
		else if(msecs(start_time) >= timeout)
			SetError(CONNECTION_TIMED_OUT);
	}
	if(phase != FAILED)
		return;
	LLOG("-- Proxy connection failed.");
	ipaddrinfo.Clear();
	packet.Clear();
	if(socket->IsOpen())
		socket->Close();
}

// Blocking convenience wrapper: starts a CONNECT to host:port through the proxy and
// drives the state machine until it stops. Returns IsSuccess().
bool NetworkProxy::Connect(const String& host, int port)
{
	StartConnect(host, port);
	bool running;
	do
		running = Do();
	while(running);
	return IsSuccess();
}

// Non-blocking start of a CONNECT request: records the target and resets the phase
// to START. The actual work happens in subsequent Do() calls.
void NetworkProxy::StartConnect(const String& host, int port)
{
	command = CONNECT;
	target_port = port;
	target_host = host;
	SetPhase(START);
}

// Advances the proxy state machine by one step according to the current phase, then
// runs CheckConnection() to detect socket errors, abort requests and timeout.
// Returns InProgress(); Connect() calls this repeatedly until it returns false.
bool NetworkProxy::Do()
{
	switch(phase) {
		case START:
			// Subclass-specific start (HttpProxy::Start / SocksProxy::Start).
			Start();
			break;
		case DNS:
			Dns();
			break;
		case REQUEST:
			// Put() returns true while part of the request is still unsent.
			if(Put())
				break;
			LLOG("++ Request succesfully sent.");
			packet.Clear();
			packet_length = 0;
			SetPhase(REPLY);
			break;
		case REPLY:
			// Get() returns true until the subclass's IsEof() sees a complete reply.
			if(Get())
				break;
			ParseReply();
			break;
		case STARTSSL:
			if(!socket->StartSSL())
				break;
			LLOG("++ SSL negotiation started.");			
			SetPhase(SSLHANDSHAKE);
			// Intentional fall through: begin the handshake in this same step.
		case SSLHANDSHAKE:
			// A true return is treated as "handshake still in progress".
			if(socket->SSLHandshake())
				break;
			LLOG("++ SSL handshake successfull.");
			// Intentional fall through: handshake done, connection established.
		case FINISHED:
			SetPhase(SUCCESS);
			LLOG(Format("++ Succesfully %s %s:%d via proxy server (at %s:%d).",
					command == CONNECT ? "connected to" : "accepted a connection from",
					target_host, target_port, proxy_host, proxy_port)
				);
			// Intentional fall through.
		case FAILED:
			break;
		default:
			// Other phases (e.g. BOUND, intercepted by SocksProxy::Do) must not reach here.
			NEVER();
		
	}
	CheckConnection();
	return InProgress();
}

// Default state: no socket attached, proxy and target ports Null, overall timeout
// 120000 ms (compared against msecs() in CheckConnection), no SSL, phase START.
NetworkProxy::NetworkProxy()
{
	socket = NULL;
	packet_length = 0;
	proxy_type = 0;
	proxy_port = Null;
	target_port = Null;
	timeout = 120000;
	error = 0;
	ssl	= false;
	// The request builders and Do()'s log read `command`; never leave it uninitialized.
	command = CONNECT;
	SetPhase(START);
}

// Discards buffered packet data and detaches from the socket. The socket is the
// caller's object (attached by reference), so it is not closed or deleted here.
NetworkProxy::~NetworkProxy()
{
	packet.Clear();
	socket = NULL;
}

// The HTTP reply header is complete once the buffer ends with "\n\r\n", i.e. the
// tail of the blank line ("\r\n\r\n") that terminates the headers.
bool HttpProxy::IsEof()
{
	const int n = packet.GetCount();
	return n > 3
	    && packet[n - 3] == '\n'
	    && packet[n - 2] == '\r'
	    && packet[n - 1] == '\n';
}

// START phase for HTTP proxies: validates/resets state, then resolves the proxy host.
void HttpProxy::Start()
{
	NetworkProxy::Init();
	if(!IsFailure()) {
		LLOG(Format("** Connecting to HTTP proxy server. (%s:%d)", proxy_host, proxy_port));
		StartDns(proxy_host, proxy_port);
	}
}

// Called after the proxy host is resolved: connects to the proxy and, once connected,
// builds the CONNECT request. If Connect() returns false nothing else happens here.
void HttpProxy::DoConnect()
{
	if(socket->Connect(ipaddrinfo)) {
		LLOG(Format("++ Successfully connected to HTTP proxy server. (%s:%d)", 
			proxy_host, proxy_port));
		ipaddrinfo.Clear();
		MakeRequest();
	}
}

void HttpProxy::MakeRequest()
{
	packet.Clear();
	packet_length = 0;
	int port = Nvl(target_port, ssl ? 443 : 80);
	packet << "CONNECT " << target_host << ":" <<  port << " HTTP/1.1\r\n"
           << "Host: " << target_host << ":" << port << "\r\n";
	if(!proxy_user.IsEmpty() && !proxy_password.IsEmpty())
		packet << "Proxy-Authorization: Basic " << Base64Encode(proxy_user + ":" + proxy_password) << "\r\n";
    packet << "\r\n";
	LLOG(">> Sending HTTP_CONNECT request.");
	LDUMPHEX(packet);
	SetPhase(REQUEST);
}

// Validates the proxy's reply to the CONNECT request. Only the status line is kept.
// - 2xx status code: the tunnel is open; continue with SSL (if requested) or finish.
// - Anything else: fail with HTTP_CONNECT_FAILED, appending the status line to the
//   error description.
void HttpProxy::ParseReply()
{
	LLOG("<< HTTP_CONNECT request reply received.");
	LDUMPHEX(packet);
	// Cut at the first line break. A bare min() would yield -1 whenever only one of
	// '\r' / '\n' is present, so handle the "not found" cases explicitly.
	int cr = packet.Find('\r');
	int lf = packet.Find('\n');
	int q = cr < 0 ? lf : lf < 0 ? cr : min(cr, lf);
	if(q >= 0)
		packet.Trim(q);
	// Status line: "HTTP/x.y <code> <reason>". Check the first digit of the status
	// code itself rather than " 2" anywhere, which a reason phrase could also match.
	int sp = packet.Find(' ');
	bool ok = packet.StartsWith("HTTP") && sp >= 0 && sp + 1 < packet.GetCount()
	          && packet[sp + 1] == '2';
	if(!ok) {
		SetError(HTTP_CONNECT_FAILED);
		error_desc << " " << packet;
	}
	else 
		SetPhase(ssl ? STARTSSL : FINISHED);
}

// Creates an HTTP CONNECT proxy client that works on the caller-owned socket `sock`.
HttpProxy::HttpProxy(TcpSocket& sock)
{
	proxy_type = HTTP;
	socket = &sock;
}

bool SocksProxy::IsEof()
{
	int result = false;
	switch(packet_type) {
		case SOCKS5_HELLO:
		case SOCKS5_AUTH: {
			result = packet.GetCharCount() == 2;
			break;
		}
		case SOCKS_REQUEST: {
			if(proxy_type == SOCKS4)  {
				result = packet.GetCharCount() == 8;
				break;
			}
			else
			if(proxy_type == SOCKS5) {
				if(packet.GetCharCount() == 5) {
					const char *address_type = packet.Last() - 1;
					if(*address_type == 0x01)
						packet_length = 10;
					else
					if(*address_type == 0x03)
						packet_length = 7 + *packet.Last();
					else
					if(*address_type == 0x04)
						packet_length = 22;
					break;
				}
				result = packet.GetCharCount() == packet_length;
				break;	
			}
		}
		default:
			NEVER();
	}
	return result;
}

// Called after the proxy host is resolved: connects to the SOCKS server.
// - dns_lookup set: the proxy resolves the target, so build the request right away.
// - Otherwise: resolve the target locally first, then MakeRequest() runs as NextStep.
void SocksProxy::DoConnect()
{
	if(!socket->Connect(ipaddrinfo))
		return;
	LLOG(Format("++ Successfully connected to SOCKS%d proxy server. (%s:%d)", 
		proxy_type, proxy_host, proxy_port));
	ipaddrinfo.Clear();
	if(dns_lookup) {
		MakeRequest();
		return;
	}
	NextStep = THISBACK(MakeRequest);
	StartDns(target_host, target_port);
}

// START phase for SOCKS proxies: validates/resets state, clears any previous bound
// address, picks the first packet to send (SOCKS4: the command request; SOCKS5: the
// greeting) and resolves the proxy host.
void SocksProxy::Start()
{
	NetworkProxy::Init();
	if(IsFailure())
		return;
	bound = false;
	memset(&bound_addr, 0, sizeof(sockaddr_storage));
	if(proxy_type == SOCKS4)
		packet_type = SOCKS_REQUEST;
	else
		packet_type = SOCKS5_HELLO;
	LLOG(Format("** Connecting to SOCKS%d proxy server. (%s:%d)", proxy_type, proxy_host, proxy_port));
	StartDns(proxy_host, proxy_port);
}

void SocksProxy::Socks4Request()
{	
	packet.Clear();
	packet_length = 0;
	packet.Cat(0x04);
	packet.Cat(command);
	if(dns_lookup) {
		uint16 port = htons(target_port);
		uint32 addr = htonl(0x00000001);
		packet.Cat((const char*) &port, sizeof(uint16));
		packet.Cat((const char*) &addr, sizeof(uint32));
	}
	else {
		struct addrinfo *info = ipaddrinfo.GetResult();
		if(info->ai_family == AF_INET6) {
			SetError(SOCKS4_ADDRESS_TYPE_NOT_SUPPORTED);
			return;
		}
		sockaddr_in *target = (sockaddr_in*) info->ai_addr;
		packet.Cat((const char*) &target->sin_port, sizeof(uint16));
		packet.Cat((const char*) &target->sin_addr.s_addr, sizeof(uint32));
		ipaddrinfo.Clear();
	}
	if(!proxy_user.IsEmpty())
		packet.Cat(proxy_user);
	packet.Cat(0x00);
	if(dns_lookup) {
		packet.Cat(target_host);
		packet.Cat(0x00);
	}
	LLOG(">> Sending SOCKS4 command request.");
	LDUMPHEX(packet);
	SetPhase(REQUEST);	
}

void SocksProxy::Socks5Request()
{
	packet.Clear();
	packet_length = 0;
	if(packet_type == SOCKS5_HELLO) {
		packet.Cat(0x05);
		packet.Cat(0x02);
		packet.Cat(0x00);
		packet.Cat(0x02);
		LLOG(">> Sending SOCKS5 initial greetings.");
	}
	else
	if(packet_type == SOCKS5_AUTH) {
		packet.Cat(0x01);
		packet.Cat(proxy_user.GetLength());
		packet.Cat(proxy_user);
		packet.Cat(proxy_password.GetLength());
		packet.Cat(proxy_password);
		LLOG(">> Sending SOCKS5 authorization request.");
	}
	else
	if(packet_type == SOCKS_REQUEST) {
		packet.Cat(0x05);
		packet.Cat(command);
		packet.Cat(0x00);
		if(dns_lookup) {
			packet.Cat(0x03);
			packet.Cat(target_host.GetLength());
			packet.Cat(target_host);
			int port = 	htons(target_port);
			packet.Cat((const char*) &port, 2);
		}
		else {
			struct addrinfo *info = ipaddrinfo.GetResult();
			if(info->ai_family == AF_INET) {
				sockaddr_in *target = (sockaddr_in*) info->ai_addr; 
				packet.Cat(0x01);
				packet.Cat((const char*) &target->sin_addr.s_addr, 4);			
				packet.Cat((const char*) &target->sin_port, 2);
			}
			else
			if(info->ai_family == AF_INET6) {
				sockaddr_in6 *target = (sockaddr_in6*) info->ai_addr; 
				packet.Cat(0x04);
				packet.Cat((const char*) &target->sin6_addr.s6_addr, 16);			
				packet.Cat((const char*) &target->sin6_port, 2);			
			}
			ipaddrinfo.Clear();
		}
		LLOG(">> Sending SOCKS5 command request.");
	}
	LDUMPHEX(packet);
	SetPhase(REQUEST);
}

// Interprets the complete reply in `packet` according to packet_type and advances
// the handshake:
// - SOCKS5_HELLO:  server's method choice (VER 0x05 + METHOD). 0x00 goes straight to
//                  the command request, 0x02 to username/password authentication,
//                  anything else fails.
// - SOCKS5_AUTH:   sub-negotiation result (VER 0x01 + STATUS). 0x00 means success.
// - SOCKS_REQUEST: command reply. The expected version byte is 0x00 (SOCKS4) or 0x05
//                  (SOCKS5), and the success status 0x5a (SOCKS4) or 0x00 (SOCKS5).
//                  Any other status byte is passed to SetError() as the error code.
void SocksProxy::ParseReply()
{
	const char *version	= packet.Begin();

	if(packet_type == SOCKS5_HELLO) {
		LLOG("<< SOCKS5 server greetings received.");
		LDUMPHEX(packet);
		const char *method = packet.Last();
		if(*version != 0x05) {
			SetError(INVALID_PACKET);
			return;
		}
		if(*method == 0x00)
			packet_type = SOCKS_REQUEST;
		else
		if(*method == 0x02)
			packet_type = SOCKS5_AUTH;
		else {
			SetError(SOCKS5_INVALID_AUTHENTICATION_METHOD);
			return;
		}
		MakeRequest();
		return;
		
	}
	else
	if(packet_type == SOCKS5_AUTH) {
		LLOG("<< SOCKS5 authorization reply received.");
		LDUMPHEX(packet);
		const char *status = packet.Last();
		if(*version != 0x01) {
			SetError(INVALID_PACKET);
			return;
		}
		if(*status != 0x00) {
			SetError(SOCKS5_AUTHENTICATION_FAILED);
			return;
		}
		packet_type = SOCKS_REQUEST;
		MakeRequest();
		return;		
		
	}
	else
	if(packet_type == SOCKS_REQUEST) {
		char ver  = proxy_type == SOCKS4 ? 0x00 : 0x05;
		char stat = proxy_type == SOCKS4 ? 0x5a : 0x00;
		LLOG(Format("<< SOCKS%d command request reply received.", proxy_type));
		LDUMPHEX(packet);
		if(*version != ver) {
			SetError(INVALID_PACKET);
			return;
		}
		const char *status = packet.Begin() + 1;
		if(*status != stat) {
			// The raw reply code doubles as the error code (see GetErrorMessage()).
			SetError(*status);
			return;
		}
		switch(command)	{
			case ACCEPT:
				// ACCEPT yields two replies. The first reports the address the proxy
				// bound (stored, then WhenBound is signalled). The second arrives
				// when the remote peer has connected.
				if(!bound) {
					bound = true;
					ParseBoundAddr();
					SetPhase(BOUND);
					packet.Clear();
					packet_length = 0;
					if(WhenBound)
						WhenBound(*this);
					LLOG(Format("++ SOCKS%d Bind() successful.", proxy_type));					
					break;
				}
				LLOG(Format("++ Socket accepted.", proxy_type));
				SetPhase(FINISHED);
				break;
			case CONNECT:
				SetPhase(ssl ? STARTSSL : FINISHED);
				break;		
		}
	}
}

// Copies the address/port from the first ACCEPT reply in `packet` into bound_addr.
// Port and address bytes are copied unchanged (network byte order).
// - SOCKS4: bytes 2-3 are the port, bytes 4-7 the IPv4 address.
// - SOCKS5: byte 3 is ATYP, the address starts at byte 4 and the port is the last
//   two bytes. Domain-name replies (ATYP 0x03) are not stored, so bound_addr keeps
//   whatever it held before (SocksProxy::Start() zeroes it).
void SocksProxy::ParseBoundAddr()
{
	switch(proxy_type) {
		case SOCKS4: {
			struct sockaddr_in *in = (sockaddr_in*) &bound_addr;
			in->sin_family = AF_INET;
			memcpy(&in->sin_port, (const char*) packet.Begin() + 2, sizeof(uint16));
			memcpy(&in->sin_addr.s_addr, (const char*) packet.Begin() + 4, sizeof(uint32));
			break;
		}
		case SOCKS5: {
			const char family = *(packet.Begin() + 3);
			switch(family) {
				case 0x01: {
					// IPv4: 4 address bytes; Last() - 1 points at the 2-byte port.
					struct sockaddr_in *in = (sockaddr_in*) &bound_addr;
					in->sin_family = AF_INET;
					memcpy(&in->sin_port, (const char*) packet.Last() - (uint16) 1, sizeof(uint16));
					memcpy(&in->sin_addr.s_addr, (const char*) packet.Begin() + 4, sizeof(uint32));
					break;
				}
				case 0x04: {
					// IPv6: 16 address bytes; Last() - 1 points at the 2-byte port.
					struct sockaddr_in6 *in6 = (sockaddr_in6*) &bound_addr;
					in6->sin6_family = AF_INET6;
					memcpy(&in6->sin6_port, (const char*) packet.Last() - (uint16) 1, sizeof(uint16));
					memcpy(&in6->sin6_addr.s6_addr, (const char*) packet.Begin() + 4, sizeof(char) * 16);
					break;
				}
				case 0x03:
					// Domain name: not converted into bound_addr.
					break;
				default:
					NEVER();
			}
		}
	}
}

// Blocking convenience wrapper: issues a SOCKS bind/accept for host:port and drives
// the state machine until it stops. Returns IsSuccess().
bool SocksProxy::Accept(const String& host, int port)
{
	for(StartAccept(host, port); Do(); ) {
	}
	return IsSuccess();
}

// Non-blocking start of an ACCEPT (bind) request: records the expected peer and
// resets the phase to START. The actual work happens in subsequent Do() calls.
void SocksProxy::StartAccept(const String& host, int port)
{
	command = ACCEPT;
	target_port = port;
	target_host = host;
	SetPhase(START);
}

// Handles the SOCKS-only BOUND phase: after the first ACCEPT reply, go back to
// waiting for the second reply. Every other phase is delegated to NetworkProxy::Do().
bool SocksProxy::Do()
{
	if(phase != BOUND)
		return NetworkProxy::Do();
	SetPhase(REPLY);
	return true;
}

// Returns the address and port (host byte order) that the proxy reported in the
// first ACCEPT reply, as stored in bound_addr by ParseBoundAddr().
// - All-zero IP with a socket attached: the socket's peer address (the proxy) is
//   returned together with the reported port.
// - No IPv4/IPv6 address recorded (e.g. a domain-name reply) and no socket: the
//   address string is empty.
Tuple2<String, int> SocksProxy::GetBoundAddr()
{
	Buffer<char> ip_buffer(16, 0), dummy(16, 0);
	String ip;
	int port = 0;
	size_t sa_size = 0, ip_size = 0;
	
	if(bound_addr.ss_family == AF_INET) {
		sockaddr_in *in = (sockaddr_in*) &bound_addr;	
		memcpy(ip_buffer, &in->sin_addr.s_addr, sizeof(in_addr));
		sa_size = sizeof(sockaddr_in);
		ip_size = 4;
		port = ntohs(in->sin_port);
	}
	else
	if(bound_addr.ss_family == AF_INET6) {
		sockaddr_in6 *in6 = (sockaddr_in6*) &bound_addr;
		memcpy(ip_buffer, &in6->sin6_addr.s6_addr, sizeof(in6_addr));
		sa_size = sizeof(sockaddr_in6);
		ip_size = 16;
		port = ntohs(in6->sin6_port);
	}
	if(socket && memcmp(ip_buffer, dummy, ip_size) == 0)
		return MakeTuple<String, int>(socket->GetPeerAddr(), port);
	// Nothing numeric to format; inet_ntop() would also fail (returning NULL) on an
	// unknown family.
	if(sa_size == 0)
		return MakeTuple<String, int>(ip, port);
	char text[64] = { 0 };
#ifdef PLATFORM_WIN32
	// inet_ntop is not available on Windows XP/2003; WSAAddressToString (XP SP1+) is.
	// WSAAddressToString also prints the port when it is non-zero, so format a copy
	// with the port cleared. The ANSI variant is named explicitly so the char buffer
	// matches regardless of UNICODE.
	sockaddr_storage addr = bound_addr;
	if(addr.ss_family == AF_INET)
		((sockaddr_in*) &addr)->sin_port = 0;
	else
		((sockaddr_in6*) &addr)->sin6_port = 0;
	DWORD text_len = sizeof(text);
	if(WSAAddressToStringA((sockaddr*) &addr, (DWORD) sa_size, 0, text, &text_len) == 0)
		ip = text;
#else
	if(inet_ntop(bound_addr.ss_family, ip_buffer, text, sizeof(text)))
		ip = text;
#endif
	return MakeTuple<String, int>(ip, port);
}

// Default SOCKS client: protocol SOCKS4, target resolved locally (dns_lookup off).
// A socket must be attached before use.
SocksProxy::SocksProxy()
{
	dns_lookup = false;
	proxy_type = SOCKS4;
}

// SOCKS client on the caller-owned socket `sock`: protocol SOCKS4, target resolved
// locally (dns_lookup off).
SocksProxy::SocksProxy(TcpSocket& sock)
{
	dns_lookup = false;
	proxy_type = SOCKS4;
	socket = &sock;
}

int ProxyConnect(int type, TcpSocket& socket, const String& proxy_host, int proxy_port, 
	const String& target_host, int target_port, const String& user, const String& password, 
	int timeout, bool ssl)
{
	One<NetworkProxy> proxy;

	switch(type) {
		case NetworkProxy::HTTP:
			proxy.Create<HttpProxy>().Auth(user, password);
			break;
		case NetworkProxy::SOCKS4: 
			proxy.Create<SocksProxy>().Auth(user).Socks4();
			break;
		case NetworkProxy::SOCKS5: 
			proxy.Create<SocksProxy>().Auth(user, password).Socks5();
			break;
		default:
			return -1;
	}
	
	proxy->Attach(socket)
		.Host(proxy_host)
		.Port(proxy_port)
		.Timeout(timeout)
		.SSL(ssl);
		
	return proxy->Connect(target_host, target_port) ? 0 : proxy->GetError();
}

int ProxyAccept(int type, TcpSocket& socket, const String& proxy_host, int proxy_port, 
	const String& target_host, int target_port, Callback1<SocksProxy&> whenbound, 
	const String& user, const String& password, int timeout)
{
	One<SocksProxy> proxy;

	switch(type) {
		case NetworkProxy::SOCKS4: 
			proxy.Create().Socks4();
			break;
		case NetworkProxy::SOCKS5: 
			proxy.Create().Socks5();
			break;
		default:
			return -1;
	}
	
	proxy->Auth(user, password)
		.Attach(socket)
		.Host(proxy_host)
		.Port(proxy_port)
		.Timeout(timeout);
	proxy->WhenBound = whenbound;
	
	return proxy->Accept(target_host, target_port) ? 0 : proxy->GetError();
}

END_UPP_NAMESPACE